{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "获得训练数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "import sys\n",
    "# 要导入代码的路径 ,utils无法导入的同学,添加上自己code的路径 ,项目代码结构 code/utils ....\n",
    "sys.path.append('../')  # 返回notebook的上一级目录\n",
    "# sys.path.append('E:\\GitHub\\QA-abstract-and-reasoning')  # 效果同上"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from gensim.models.word2vec import LineSentence\n",
    "from gensim.models import word2vec\n",
    "from utils.config import *\n",
    "from utils.loader import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 词向量的载入\n",
    "wv_model = word2vec.Word2Vec.load(WV_MODEL)\n",
    "vocab_index, index_vocab = load_vocab(VOCAB_INDEX)\n",
    "embedding_matrix = np.loadtxt(EMBEDDING_MATRIX)  # "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(82943, 6) \n",
      " (20000, 5)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>QID</th>\n",
       "      <th>Brand</th>\n",
       "      <th>Model</th>\n",
       "      <th>Question</th>\n",
       "      <th>Dialogue</th>\n",
       "      <th>Report</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Q1</td>\n",
       "      <td>奔驰</td>\n",
       "      <td>奔驰GL级</td>\n",
       "      <td>方向机 重 助力 泵 方向机 都 换</td>\n",
       "      <td>新 都 换 助力 泵 方向机 换 方向机 带 助力 重 这车 匹配 不 需要 更换 部件 问...</td>\n",
       "      <td>随时 联系</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Q2</td>\n",
       "      <td>奔驰</td>\n",
       "      <td>奔驰M级</td>\n",
       "      <td>奔驰 ML500 排气 凸轮轴 调节 错误</td>\n",
       "      <td>有没有 电脑 检测 故障 代码 有发 一下 发动机 之前 亮 故障 灯 显示 失火 有点 缺...</td>\n",
       "      <td>随时 联系</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Q3</td>\n",
       "      <td>宝马</td>\n",
       "      <td>宝马X1</td>\n",
       "      <td>2010 款 宝马X1 2011 年 出厂 20 排量 通用 6L45 变速箱 原地 换挡 ...</td>\n",
       "      <td>4 缸 自然 吸气 发动机 N46 先 挂 空档 再 挂 档 有没有 闯动 变速箱 油液 位...</td>\n",
       "      <td>行驶 没有 顿挫 感觉 原地 换挡 闯动 刹车 踩 重 没有 力 限制 作用 应该 没有 问题</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Q4</td>\n",
       "      <td>JEEP</td>\n",
       "      <td>牧马人</td>\n",
       "      <td>30V6 发动机 号 位置 照片 最好</td>\n",
       "      <td>右侧 排气管 上方 缸体 上 靠近 变速箱 是不是 号 不 先拓 下来 行车证 下 不是 有...</td>\n",
       "      <td>举起 车辆 左 前轮 缸体 上</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Q5</td>\n",
       "      <td>奔驰</td>\n",
       "      <td>奔驰C级</td>\n",
       "      <td>2012 款 奔驰 C180 维修保养 动力 值得 拥有</td>\n",
       "      <td>家庭 用车 入手 维修保养 费用 不高 12 年 180 市场价 钱 现在 想 媳妇 买 属...</td>\n",
       "      <td>家庭 用车 入手 维修保养 价格 还 车况 好 价格合理 入手</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  QID Brand  Model                                           Question  \\\n",
       "0  Q1    奔驰  奔驰GL级                                 方向机 重 助力 泵 方向机 都 换   \n",
       "1  Q2    奔驰   奔驰M级                              奔驰 ML500 排气 凸轮轴 调节 错误   \n",
       "2  Q3    宝马   宝马X1  2010 款 宝马X1 2011 年 出厂 20 排量 通用 6L45 变速箱 原地 换挡 ...   \n",
       "3  Q4  JEEP    牧马人                                30V6 发动机 号 位置 照片 最好   \n",
       "4  Q5    奔驰   奔驰C级                       2012 款 奔驰 C180 维修保养 动力 值得 拥有   \n",
       "\n",
       "                                            Dialogue  \\\n",
       "0  新 都 换 助力 泵 方向机 换 方向机 带 助力 重 这车 匹配 不 需要 更换 部件 问...   \n",
       "1  有没有 电脑 检测 故障 代码 有发 一下 发动机 之前 亮 故障 灯 显示 失火 有点 缺...   \n",
       "2  4 缸 自然 吸气 发动机 N46 先 挂 空档 再 挂 档 有没有 闯动 变速箱 油液 位...   \n",
       "3  右侧 排气管 上方 缸体 上 靠近 变速箱 是不是 号 不 先拓 下来 行车证 下 不是 有...   \n",
       "4  家庭 用车 入手 维修保养 费用 不高 12 年 180 市场价 钱 现在 想 媳妇 买 属...   \n",
       "\n",
       "                                            Report  \n",
       "0                                            随时 联系  \n",
       "1                                            随时 联系  \n",
       "2  行驶 没有 顿挫 感觉 原地 换挡 闯动 刹车 踩 重 没有 力 限制 作用 应该 没有 问题  \n",
       "3                                  举起 车辆 左 前轮 缸体 上  \n",
       "4                  家庭 用车 入手 维修保养 价格 还 车况 好 价格合理 入手  "
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 预处理数据载入\n",
    "train_seg = pd.read_csv(TRAIN_SEG).fillna(\"\")\n",
    "test_seg = pd.read_csv(TEST_SEG).fillna(\"\")\n",
    "print(train_seg.shape,\"\\n\",test_seg.shape)\n",
    "train_seg.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    方向机 重 助力 泵 方向机 都 换 新 都 换 助力 泵 方向机 换 方向机 带 助力 重...\n",
       "1    奔驰 ML500 排气 凸轮轴 调节 错误 有没有 电脑 检测 故障 代码 有发 一下 发动...\n",
       "2    2010 款 宝马X1 2011 年 出厂 20 排量 通用 6L45 变速箱 原地 换挡 ...\n",
       "3    30V6 发动机 号 位置 照片 最好 右侧 排气管 上方 缸体 上 靠近 变速箱 是不是 ...\n",
       "4    2012 款 奔驰 C180 维修保养 动力 值得 拥有 家庭 用车 入手 维修保养 费用 ...\n",
       "Name: X, dtype: object"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_seg['X'] = train_seg[['Question', 'Dialogue']].apply(lambda x: ' '.join(x), axis=1)\n",
    "test_seg['X'] = test_seg[['Question', 'Dialogue']].apply(lambda x: ' '.join(x), axis=1)\n",
    "train_seg['X'].head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 句子长度不一样: 填充"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad(sentence, max_len, vocab):\n",
    "    \"\"\"\n",
    "    <START> <END> <PAD> <UNK>\n",
    "    \"\"\"\n",
    "\n",
    "    # 0.按空格统计切分出词\n",
    "    words = sentence.strip().split(' ')\n",
    "    # 1. [截取]规定长度的词数\n",
    "    words = words[:max_len]\n",
    "    # 2. 填充< unk > ,判断是否在vocab中, 不在填充 < unk >\n",
    "    sentence = [word if word in vocab else '<UNK>' for word in words]\n",
    "    # 3. 填充< start > < end >\n",
    "    sentence = ['<START>'] + sentence + ['<STOP>']\n",
    "    # 4. 判断长度，填充　< pad >\n",
    "    sentence = sentence + ['<PAD>'] * (max_len - len(words))\n",
    "    return ' '.join(sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'方向机 重 助力 泵 方向机 都 换 新 都 换 助力 泵 方向机 换 方向机 带 助力 重 这车 匹配 不 需要 更换 部件 问题 跑 快 还好 点 倒车 重 很 非常 重 累人 觉得 车主 以前 没 重选 助理 泵 换 不行 放 机换 现在 还 不 知道 车主 解释'"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_seg['X'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<START> 方向机 重 助力 泵 方向机 都 换 新 都 换 助力 泵 方向机 换 方向机 带 助力 重 这车 匹配 不 需要 更换 部件 问题 跑 快 还好 点 倒车 重 很 非常 重 累人 觉得 车主 以前 没 <UNK> 助理 泵 换 不行 放 机换 现在 还 不 知道 车主 解释 <STOP> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>'"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pad(train_seg['X'][0], 60, vocab_index)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 获取适当的max_len"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最大长度值 = 每句话长度均值 + 2倍每句话长度标准差"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_max_len(data):\n",
    "    \"\"\"\n",
    "    获得合适的最大长度值\n",
    "    :param data: 待统计的数据  train_df['Question']\n",
    "    :return: 最大长度值\n",
    "    \"\"\"\n",
    "    max_lens = data.apply(lambda x: x.count(' '))\n",
    "    return int(np.mean(max_lens) + 2 * np.std(max_lens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 获取输入数据 适当的最大长度\n",
    "train_x_max_len = get_max_len(train_seg['X'])\n",
    "test_x_max_len = get_max_len(test_seg['X'])\n",
    "\n",
    "x_max_len = max(train_x_max_len, test_x_max_len)\n",
    "\n",
    "# 获取标签数据 适当的最大长度\n",
    "train_y_max_len = get_max_len(train_seg['Report'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "训练集x最大长度 247\n",
      "测试集x最大长度 258\n",
      "训练集y最大长度 31\n"
     ]
    }
   ],
   "source": [
    "print(\"训练集x最大长度 {}\".format(train_x_max_len))\n",
    "print(\"测试集x最大长度 {}\".format(test_x_max_len))\n",
    "print(\"训练集y最大长度 {}\".format(train_y_max_len))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 训练集X处理\n",
    "train_seg['X'] = train_seg['X'].apply(lambda x: pad(x, x_max_len, vocab_index))\n",
    "# 训练集Y处理\n",
    "train_seg['Y'] = train_seg['Report'].apply(lambda x: pad(x, train_y_max_len, vocab_index))\n",
    "# 测试集X处理\n",
    "test_seg['X'] = test_seg['X'].apply(lambda x: pad(x, x_max_len, vocab_index))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 保存中间结果数据\n",
    "train_seg['X'].to_csv(TRAIN_X_PAD, index=None, header=False)\n",
    "train_seg['Y'].to_csv(TRAIN_Y_PAD, index=None, header=False)\n",
    "test_seg['X'].to_csv(TEST_X_PAD, index=None, header=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 更新词表"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "start retrain w2v model\n",
      "1/3\n",
      "2/3\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(8690390, 26000000)"
      ]
     },
     "execution_count": 80,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print('start retrain w2v model')\n",
    "wv_model.build_vocab(LineSentence(TRAIN_X_PAD), update=True)\n",
    "wv_model.train(LineSentence(TRAIN_X_PAD), epochs=5, total_examples=wv_model.corpus_count)\n",
    "print('1/3')\n",
    "wv_model.build_vocab(LineSentence(TRAIN_Y_PAD), update=True)\n",
    "wv_model.train(LineSentence(TRAIN_Y_PAD), epochs=5, total_examples=wv_model.corpus_count)\n",
    "print('2/3')\n",
    "wv_model.build_vocab(LineSentence(TEST_X_PAD), update=True)\n",
    "wv_model.train(LineSentence(TEST_X_PAD), epochs=5, total_examples=wv_model.corpus_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 保存词向量模型\n",
    "wv_model.save(WV_MODEL_PAD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((32562, 300), (32566, 300))"
      ]
     },
     "execution_count": 96,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedding_matrix.shape, wv_model.wv.vectors.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 更新vocab和embedding\n",
    "vocab_index = {word: index for index, word in enumerate(wv_model.wv.index2word)}\n",
    "index_vocab = {index: word for index, word in enumerate(wv_model.wv.index2word)}\n",
    "embedding_matrix = wv_model.wv.vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    <START> 方向机 重 助力 泵 方向机 都 换 新 都 换 助力 泵 方向机 换 方向...\n",
       "1    <START> 奔驰 ML500 排气 凸轮轴 调节 错误 有没有 电脑 检测 故障 代码 ...\n",
       "2    <START> 2010 款 宝马X1 2011 年 出厂 20 排量 通用 <UNK> 变...\n",
       "3    <START> 30V6 发动机 号 位置 照片 最好 右侧 排气管 上方 缸体 上 靠近 ...\n",
       "4    <START> 2012 款 奔驰 C180 维修保养 动力 值得 拥有 家庭 用车 入手 ...\n",
       "Name: X, dtype: object"
      ]
     },
     "execution_count": 100,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_seg['X'].head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 遇到未知词就填充unk的索引\n",
    "unk_index = vocab_index['<UNK>']\n",
    "def transform_data(sentence,vocab):\n",
    "    # 字符串切分成词\n",
    "    words=sentence.split(' ')\n",
    "    # 按照vocab的index进行转换\n",
    "    ids=[vocab[word] if word in vocab else unk_index for word in words]\n",
    "    return ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<START> 方向机 重 助力 泵 方向机 都 换 新 都 换 助力 泵 方向'"
      ]
     },
     "execution_count": 109,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_seg['X'][0][:40]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[32562, 403, 985, 247, 231, 403, 4, 10, 94, 4]"
      ]
     },
     "execution_count": 107,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transform_data(train_seg['X'][0],vocab_index)[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 将词转换成索引  [<START> 方向机 重 ...] -> [32800, 403, 986, 246, 231\n",
    "train_ids_x=train_seg['X'].apply(lambda x:transform_data(x,vocab_index))\n",
    "train_ids_y=train_seg['Y'].apply(lambda x:transform_data(x,vocab_index))\n",
    "test_ids_x=test_seg['X'].apply(lambda x:transform_data(x,vocab_index))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    [32562, 403, 985, 247, 231, 403, 4, 10, 94, 4,...\n",
       "1    [32562, 816, 26474, 388, 219, 361, 1610, 33, 2...\n",
       "2    [32562, 1445, 81, 5125, 1290, 64, 1138, 375, 6...\n",
       "3    [32562, 10244, 9, 307, 62, 523, 170, 933, 331,...\n",
       "4    [32562, 1348, 81, 816, 8963, 3282, 208, 2079, ...\n",
       "Name: X, dtype: object"
      ]
     },
     "execution_count": 113,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_ids_x.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:tf2.0]",
   "language": "python",
   "name": "conda-env-tf2.0-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
